import numpy as np
import matplotlib.pyplot as plt
from SVMAgent import MLPAgent, FastComNetwork, ComNetwork, DatasetModel ,MLP,MLPOracle
import random


# Hyper-parameters. Both RNGs are seeded once, before anything random runs.
random.seed(42)
np.random.seed(42)
NUM_AGENTS = 16
NUM_ROUNDS = 200
T_RESTART = 20  # agents' actions are re-initialised every T_RESTART rounds (DOC²S / ME-DOL)
DELTA = 1e-3  # zeroth-order estimator parameter (`delta` of oracle.get_zo_grad)
LR = 0.01  # DOC²S learning rate
D = 0.005  # DOC²S radius (norm bound on the active agent's action)
m_LR = 0.01  # ME-DOL learning rate
m_D = 0.001  # ME-DOL radius (passed to MLPAgent as D)
dgfm_lr1 = 0.001  # DGFM step size; NOTE(review): passed to MLPAgent as D, not lr — confirm
dgfm_lr2 = 0.001  # DGFM+ step size; NOTE(review): passed to MLPAgent as D, not lr — confirm
R = 1  # Chebyshev communication rounds per step (R of FastComNetwork.propagate_*)
p = 0.99  # self-weight of the ring mixing matrix; controls its spectral gap
BATCH_SIZE = 128
DATASET_NAME = 'fashion-mnist'
HIDDEN_DIM = 128

# Shared data source and loss/gradient oracle used by every training routine.
dataset = DatasetModel(
    dsname=DATASET_NAME,
    num_agent=NUM_AGENTS,
    mb_size=BATCH_SIZE,
    max_sample=10000
)
oracle = MLPOracle(lam=1e-5, hidden_dim=HIDDEN_DIM)

# ring_matrix
def ring_matrix(n, p):
    """Build the n x n mixing matrix of a ring (cycle) graph.

    Each node keeps weight ``p`` on itself and splits ``1 - p`` equally
    between its two ring neighbours, so every row and column sums to 1.

    Contributions are accumulated with ``+=`` so that the matrix stays
    doubly stochastic on tiny rings. On those rings the two neighbours
    coincide (n == 2) or coincide with the node itself (n == 1).

    Args:
        n: number of nodes.
        p: self-weight placed on the diagonal.

    Returns:
        numpy array of shape (n, n).
    """
    W = np.zeros((n, n))
    side = (1 - p) / 2
    for i in range(n):
        W[i, (i - 1) % n] += side
        W[i, i] += p
        W[i, (i + 1) % n] += side
    return W


def create_matrix(n):
    """Return the n x n uniform-averaging (complete-graph) mixing matrix.

    Every entry equals ``1 / n``.
    """
    weight = 1.0 / n
    return np.full(shape=(n, n), fill_value=weight)


# DOC²S train loss
def train_DOC2S(agents, num_rounds, t_restart):
    """Run DOC²S and return the test-loss curve of the averaged model.

    In each round one agent, drawn uniformly at random, updates its weight and
    takes a norm-capped step on its action using a minibatch from its own
    shard. Every other agent's action is zeroed, and actions and weights are
    then mixed with the uniform (complete-graph) matrix.

    Args:
        agents: list of agent objects, one per shard; mutated in place.
        num_rounds: number of optimisation rounds.
        t_restart: actions are re-initialised whenever
            ``k % t_restart == t_restart - 1``.

    Returns:
        List of test losses of the network-averaged weights: the initial
        loss, then one entry every 10 rounds (k = 10, 20, ...).
    """
    net = FastComNetwork(create_matrix(NUM_AGENTS))
    curve = []

    def evaluate():
        # Test loss of the network-averaged weights.
        mean_weight = net.get_average_weight(agents)
        X_eval, y_eval = dataset.get_test_set()
        return oracle.get_fn_val(mean_weight, X_eval, y_eval)

    start_loss = evaluate()
    curve.append(start_loss)
    print(f"round 0, {start_loss:.4f}")

    for k in range(num_rounds):
        if k % t_restart == (t_restart - 1):
            for member in agents:
                member.initialize_action()

        # Pick the single active agent for this round and draw its minibatch.
        active = np.random.randint(NUM_AGENTS)
        x_batch, y_batch = dataset.get_sample(active)

        for member in agents:
            member.get_grad_point()

        leader = agents[active]
        leader.set_weight(leader.DOC2S_get_new_weight())

        g = oracle.get_gradients(leader.get_grad_points(), x_batch, y_batch)

        for idx, member in enumerate(agents):
            if idx != active:
                member.set_action(np.zeros_like(member.get_action()))
                continue
            # Gradient step, shrunk onto the ball of radius D, then scaled by
            # NUM_AGENTS so uniform averaging recovers the step itself.
            step = member.get_action() - member.lr * g
            magnitude = np.linalg.norm(step)
            shrink = min(1, member.D / magnitude) if magnitude > 1e-8 else 1.0
            member.set_action(member.NUM_AGENTS * shrink * step)

        net.propagate_actions(agents, R)
        net.propagate_weights(agents, R)

        if k > 0 and k % 10 == 0:
            loss = evaluate()
            curve.append(loss)
            print(f"round {10*k}, {loss:.4f}")

    return curve


# MEDOL train loss
def train_MEDOL(agents, num_rounds, t_restart):
    """Run ME-DOL on a ring network and return the test-loss curve.

    In each round every agent in turn computes its new weight, evaluates a
    gradient on a minibatch from its own shard, and updates its action.
    Actions and weights are then mixed over the ring.

    Args:
        agents: list of agent objects; agent ``m`` samples from shard ``m``.
            Mutated in place.
        num_rounds: number of optimisation rounds.
        t_restart: actions are re-initialised whenever
            ``k % t_restart == t_restart - 1``.

    Returns:
        List of test losses of the network-averaged weights: the initial
        loss, then one entry every 10 rounds (k = 10, 20, ...). This matches
        ``train_DOC2S``'s layout, so both curves share the plot's x-axis.
    """
    network = ComNetwork(ring_matrix(NUM_AGENTS, p))
    losses = []

    avg_w = network.get_average_weight(agents)
    X_test, y_test = dataset.get_test_set()
    loss = oracle.get_fn_val(avg_w, X_test, y_test)
    # Keep the round-0 point. Without it this curve is one entry shorter than
    # train_DOC2S's, and plotting against the shared x-axis fails.
    losses.append(loss)
    print(f"round 0, {loss:.4f}")

    for k in range(num_rounds):
        if k % t_restart == (t_restart - 1):
            for agent in agents:
                agent.initialize_action()

        for m, agent in enumerate(agents):
            agent.get_grad_point()
            new_weight = agent.get_new_weight()
            agent.set_weight(new_weight)

            x_mb, y_mb = dataset.get_sample(m)
            grad_point = agent.get_grad_points()
            grad = oracle.get_gradients(grad_point, x_mb, y_mb)

            agent.action_grad_update(grad)

        network.propagate_actions(agents)
        network.propagate_weights(agents)

        if k % 10 == 0 and k > 0:
            avg_w = network.get_average_weight(agents)
            X_test, y_test = dataset.get_test_set()
            loss = oracle.get_fn_val(avg_w, X_test, y_test)
            losses.append(loss)
            # Progress line; the text reads "round <n>, loss: <value>".
            print(f"轮次 {10*k}, 损失: {loss:.4f}")

    return losses

def train_DGFM(agents, num_rounds, *, return_max_losses=False):
    """Run DGFM (zeroth-order, gradient-tracking) and record test losses.

    In each round every agent estimates a zeroth-order gradient on a minibatch
    from its own shard and feeds it to ``update_y_grad``. The gradient state is
    then communicated via ``propagate_dgfm_grad``, every agent updates its
    weight, and the weights are mixed over the ring.

    Args:
        agents: list of agent objects (sampling uses ``agent.id`` as the
            shard index); mutated in place.
        num_rounds: number of optimisation rounds.
        return_max_losses: if True, also record the largest test loss among
            the individual agents each round. This costs one extra test
            evaluation per agent per round.

    Returns:
        ``losses``: the test loss of the network-averaged weights after every
        round (``num_rounds`` entries, no initial point). When
        ``return_max_losses`` is True, returns ``(losses, max_losses)``.
    """
    network = ComNetwork(ring_matrix(NUM_AGENTS, p))
    losses = []
    max_losses = []
    for _ in range(num_rounds):
        # All clients compute a zeroth-order gradient estimate.
        for agent in agents:
            x_mb, y_mb = dataset.get_sample(agent.id)
            w = agent.get_weight()
            grad = oracle.get_zo_grad(w, x_mb, y_mb, delta=DELTA)
            agent.update_y_grad(grad)

        # Communication of the gradient-tracking state.
        network.propagate_dgfm_grad(agents)

        for agent in agents:
            agent.update_weight()

        network.propagate_weights(agents)

        # Test loss of the averaged weights (test set fetched once per round).
        avg_weight = network.get_average_weight(agents)
        X_test, y_test = dataset.get_test_set()
        losses.append(oracle.get_fn_val(avg_weight, X_test, y_test))

        # The worst per-agent loss needs NUM_AGENTS extra evaluations, so it is
        # computed only when the caller asks for it.
        if return_max_losses:
            max_losses.append(max(
                oracle.get_fn_val(agent.get_weight(), X_test, y_test)
                for agent in agents
            ))

    if return_max_losses:
        return losses, max_losses
    return losses

def train_DGFM_plus(agents, num_rounds, T_restart=10, mega_batch=512, *,
                    batch_size=64, restart_comm_rounds=5,
                    return_max_losses=False):
    """Run DGFM+ (zeroth-order, SPIDER-style gradient tracking with restarts).

    Every ``T_restart`` rounds, each agent's tracking variable ``v`` is reset
    to a zeroth-order gradient estimate on a large minibatch of ``mega_batch``
    samples. That ``v`` is mixed with ``restart_comm_rounds`` communication
    rounds, and each agent then calls ``save_prev_grad``.

    In every round, each agent applies a SPIDER update with a ``batch_size``
    minibatch. ``v`` is mixed once, every agent updates its weight, and the
    weights are mixed once.

    Args:
        agents: list of agent objects (sampling uses ``agent.id`` as the
            shard index); mutated in place.
        num_rounds: number of optimisation rounds.
        T_restart: restart period in rounds; must be positive.
        mega_batch: minibatch size used for the restart gradient estimate.
        batch_size: minibatch size used for the regular SPIDER updates.
        restart_comm_rounds: communication rounds used to mix ``v`` at a
            restart.
        return_max_losses: if True, also record the largest test loss among
            the individual agents each round. This costs one extra test
            evaluation per agent per round.

    Returns:
        ``losses``: the test loss of the network-averaged weights after every
        round (``num_rounds`` entries, no initial point). When
        ``return_max_losses`` is True, returns ``(losses, max_losses)``.

    Raises:
        ValueError: if ``T_restart`` is not positive.
    """
    if T_restart <= 0:
        raise ValueError(f"T_restart must be positive, got {T_restart}")

    network = FastComNetwork(ring_matrix(NUM_AGENTS, p))
    losses = []
    max_losses = []

    for k in range(num_rounds):
        if k % T_restart == 0:
            # === Restart phase ===
            # 1. Reset v_i^{rT} = g_i(x_i^{rT}; S_i^{rT}) on a large minibatch.
            for agent in agents:
                x_mb, y_mb = dataset.get_sample_DGFM(agent.id, mb_size=mega_batch)
                grad = oracle.get_zo_grad(agent.get_weight(), x_mb, y_mb, delta=DELTA)
                agent.set_v(grad)

            # 2. Several (Chebyshev-accelerated) communication rounds on v.
            network.propagate_v(agents, R=restart_comm_rounds)

            # 3. Save the current gradient as the SPIDER reference.
            for agent in agents:
                agent.save_prev_grad()

        # === Regular iteration ===
        for agent in agents:
            x_mb, y_mb = dataset.get_sample_DGFM(agent.id, mb_size=batch_size)
            grad_new = oracle.get_zo_grad(agent.get_weight(), x_mb, y_mb, delta=DELTA)
            agent.update_spider_grad(grad_new)

        # A single communication round for the tracking variable v.
        network.propagate_v(agents, R=1)

        # Weight update, then a single communication round for the weights.
        for agent in agents:
            agent.update_weight()
        network.propagate_weights(agents, R=1)

        # Test loss of the averaged weights (test set fetched once per round).
        avg_weight = network.get_average_weight(agents)
        X_test, y_test = dataset.get_test_set()
        losses.append(oracle.get_fn_val(avg_weight, X_test, y_test))

        if return_max_losses:
            max_losses.append(max(
                oracle.get_fn_val(agent.get_weight(), X_test, y_test)
                for agent in agents
            ))

    if return_max_losses:
        return losses, max_losses
    return losses

def _make_agents(lr, radius):
    """Create one MLPAgent per shard with the shared architecture.

    ``radius`` is forwarded as the agent's ``D`` argument; agent ``i`` gets
    ``id=i``.
    """
    return [
        MLPAgent(
            input_dim=dataset.input_dim,
            hidden_dim=HIDDEN_DIM,
            id=shard,
            lr=lr,
            D=radius,
            NUM_AGENTS=NUM_AGENTS,
        )
        for shard in range(NUM_AGENTS)
    ]


doc2s_agents = _make_agents(LR, D)
medol_agents = _make_agents(m_LR, m_D)
# NOTE(review): dgfm_lr1 / dgfm_lr2 are documented as learning rates but are
# passed as D here, while lr stays m_LR. Confirm this is intended.
dgfm_agents = _make_agents(m_LR, dgfm_lr1)
dgfmp_agents = _make_agents(m_LR, dgfm_lr2)

# Start every method from the same weights (copied from the first DOC²S
# agent) so the curves are comparable.
base_weight = doc2s_agents[0].get_weight().copy()
all_groups = (doc2s_agents, medol_agents, dgfm_agents, dgfmp_agents)
for group in all_groups:
    for member in group:
        member.set_weight(base_weight.copy())

doc2s_losses = train_DOC2S(doc2s_agents, NUM_ROUNDS, T_RESTART)
medol_losses = train_MEDOL(medol_agents, NUM_ROUNDS, T_RESTART)
# NOTE(review): the "DGFM" and "DGFM+" curves are produced by train_MEDOL,
# not by train_DGFM / train_DGFM_plus, so they are ME-DOL runs with a
# different D. Those functions return a per-round curve with no initial
# point. Switching over needs that output brought to train_DOC2S's layout
# (initial point + every 10 rounds) first. TODO confirm the intent.
dgfm_losses = train_MEDOL(dgfm_agents, NUM_ROUNDS, T_RESTART)
dgfmp_losses = train_MEDOL(dgfmp_agents, NUM_ROUNDS, T_RESTART)

# Figure: test loss of each method against computation rounds.
plt.figure(figsize=(8.5, 7.5))
# One point for round 0, then one per logged evaluation. Each logged point is
# placed at i * 100, matching the "round {10*k}" values printed in training.
x_points = [0] + [i * 100 for i in range(1, len(doc2s_losses))]

# Raw strings: "\m" in a normal literal is an invalid escape sequence.
curves = [
    (doc2s_losses, dict(label=r"$\mathrm{DOC^2S}$", color="black", marker='o', markersize=11)),
    (medol_losses, dict(label=r"$\mathrm{ME{-}DOL}$", color="red", linestyle="--", marker='*', markersize=11)),
    (dgfm_losses, dict(label=r"$\mathrm{DGFM}$", color="blue", marker='s', linestyle="-.", markersize=11)),
    (dgfmp_losses, dict(label=r"$\mathrm{DGFM+}$", color="green", marker='^', markersize=11)),
]

# Fail with a descriptive message if a curve does not match the shared x-axis.
for series, style in curves:
    if len(series) != len(x_points):
        raise ValueError(
            f"{style['label']}: {len(series)} loss values but "
            f"{len(x_points)} x positions"
        )

for series, style in curves:
    plt.plot(x_points, series, **style)

plt.xlabel(r"$\mathrm{Computation~Rounds}$", fontsize=31)
plt.ylabel(r"$\mathrm{Function~value}$", fontsize=31)
plt.xticks(fontsize=20)
plt.yticks(fontsize=20)
plt.legend(fontsize=27)
plt.tight_layout()
plt.grid(True, alpha=0.3)
plt.show()